{# ----------------------------------------------------------------------------
 # SymForce - Copyright 2022, Skydio, Inc.
 # This source code is under the Apache 2.0 license found in the LICENSE file.
 # ---------------------------------------------------------------------------- #}

{%- import "../util/util.jinja" as util with context -%}

#pragma once

#include <ostream>
#include <random>
#include <vector>
#include <Eigen/Core>
{# If a Rot type, include Geometry for Eigen::Quaternion #}
{% if cls.__name__.startswith('Rot') %}
#include <Eigen/Geometry>
{% endif %}

{# If a pose type, include the necessary rotation type. #}
{% if cls.__name__.startswith('Pose') %}
#include <sym/rot{{ cls.__name__[-1] }}.h>
{% elif cls.__name__ == 'Unit3' %}
#include <sym/rot3.h>
{% endif -%}

#include <sym/ops/storage_ops.h>
#include <sym/ops/group_ops.h>
#include <sym/ops/lie_group_ops.h>
#include <sym/util/epsilon.h>

namespace sym {

{% if doc %}
/**
 * Autogenerated C++ implementation of `{{ cls.__module__ }}.{{ cls.__qualname__ }}`.
 *
{% for line in doc.split('\n') %}
 *{{ ' {}'.format(line).rstrip() }}
{% endfor %}
 */
{% endif %}
template <typename ScalarType>
class {{ cls.__name__ }} {
 public:
  // Typedefs
  using Scalar = ScalarType;
  using Self = {{ cls.__name__ }}<Scalar>;
  // Storage vector. Eigen::DontAlign keeps Eigen from imposing vectorization alignment on this
  // member, which the size/alignment static_asserts following the class check for.
  using DataVec = Eigen::Matrix<Scalar, {{ ops.StorageOps.storage_dim(cls) }}, 1, Eigen::DontAlign>;
  {% if is_manifold %}
  using TangentVec = Eigen::Matrix<Scalar, {{ ops.LieGroupOps.tangent_dim(cls) }}, 1>;
  using SelfJacobian = Eigen::Matrix<Scalar, {{ ops.LieGroupOps.tangent_dim(cls) }}, {{ ops.LieGroupOps.tangent_dim(cls) }}>;
  {% endif %}

  /**
   * Construct from data vec
   *
   * @param data Storage vector copied into this object.
   * @param normalize Project to the manifold on construction.  This ensures numerical stability as
   *     this constructor is called after each codegen operation.  Constructing from a normalized
   *     vector may be faster, e.g. with `FromStorage`.
   */
  explicit {{ cls.__name__ }}(const DataVec& data, const bool normalize = true)
  {% if cls in (sf.Rot2, sf.Rot3, sf.Unit3) %}
      : data_(normalize ? DataVec(data.normalized()) : data) {}
  {% elif cls == sf.Pose3 %}
      : data_(data) { if (normalize) { data_.template head<4>() = data_.template head<4>().normalized();} }
  {% elif cls == sf.Pose2 %}
      : data_(data) { if (normalize) { data_.template head<2>() = data_.template head<2>().normalized();} }
  {% else %}
      : data_(data) { (void)normalize; }
  {% endif %}

  {% if is_group %}
  // Default construct to identity
  {{ cls.__name__ }}() : {{ cls.__name__ }}(GroupOps<Self>::Identity()) {}
  {% endif %}

  // Access underlying storage as const
  inline const DataVec& Data() const {
      return data_;
  }

  {% if matrix_type_aliases %}
  // Matrix type aliases
  {% endif %}
  {% for alias in matrix_type_aliases.items() %}
  using {{ alias[1] }} = {{ alias[0] }};
  {% endfor %}

  {% set custom_template_name = "custom_methods/{}.h.jinja".format(cls.__name__.lower()) %}
  // --------------------------------------------------------------------------
  // Handwritten methods included from "{{ custom_template_name }}"
  // --------------------------------------------------------------------------

  {% include custom_template_name %}

  {% if custom_generated_methods %}
  // --------------------------------------------------------------------------
  // Custom generated methods
  // --------------------------------------------------------------------------

  {% endif %}
  {% set compose_with_point = namespace(defined=false) %}
  {% for spec in custom_generated_methods %}
    {% if spec.name == "compose_with_point" %}
      {% set compose_with_point.defined = true %}
      {% set compose_with_point.dimension = typing_util.get_type(spec.inputs["right"]).SHAPE[0] %}
    {% endif %}
    {# Return values from methods are const - https://github.com/symforce-org/symforce/issues/312 #}
  {{ util.print_docstring(spec.docstring) | indent(2) }}
  const {{ python_util.str_replace_all(util.method_declaration(spec, is_declaration=True), matrix_type_aliases) }};

  {% endfor %}
  // --------------------------------------------------------------------------
  // StorageOps concept
  // --------------------------------------------------------------------------

  static constexpr int32_t StorageDim() {
    return StorageOps<Self>::StorageDim();
  }

  // Write the storage representation into `vec`, which must have room for StorageDim() scalars.
  void ToStorage(Scalar* const vec) const {
    StorageOps<Self>::ToStorage(*this, vec);
  }

  // Construct from a storage array of StorageDim() scalars.
  static {{ cls.__name__ }} FromStorage(const Scalar* const vec) {
    return StorageOps<Self>::FromStorage(vec);
  }

  {% if compose_with_point.defined %}
  // Overload of Compose for points; forwards to the generated ComposeWithPoint.
  Vector{{ compose_with_point.dimension }} Compose(const Vector{{ compose_with_point.dimension }}& point) const {
    return ComposeWithPoint(point);
  }
  {% endif %}

  {% if is_group %}
  // --------------------------------------------------------------------------
  // GroupOps concept
  // --------------------------------------------------------------------------

  static Self Identity() {
    return GroupOps<Self>::Identity();
  }

  Self Inverse() const {
    return GroupOps<Self>::Inverse(*this);
  }

  Self Compose(const Self& b) const {
    return GroupOps<Self>::Compose(*this, b);
  }

  Self Between(const Self& b) const {
    return GroupOps<Self>::Between(*this, b);
  }

  // The Jacobian output pointers are optional; pass nullptr to skip computing them.
  Self InverseWithJacobian(SelfJacobian* const res_D_a = nullptr) const {
    return GroupOps<Self>::InverseWithJacobian(*this, res_D_a);
  }

  Self ComposeWithJacobians(const Self& b, SelfJacobian* const res_D_a = nullptr,
                            SelfJacobian* const res_D_b = nullptr) const {
    return GroupOps<Self>::ComposeWithJacobians(*this, b, res_D_a, res_D_b);
  }

  Self BetweenWithJacobians(const Self& b, SelfJacobian* const res_D_a = nullptr,
                            SelfJacobian* const res_D_b = nullptr) const {
    return GroupOps<Self>::BetweenWithJacobians(*this, b, res_D_a, res_D_b);
  }

  // Compose shorthand. The trailing decltype removes this operator from overload resolution for
  // any `Other` that no Compose overload accepts, and picks up the matching return type (Self,
  // or the point type when a point Compose overload is generated).
  template <typename Other>
  auto operator*(const Other& b) const -> decltype(Compose(b)) {
    return Compose(b);
  }
  {% endif %}

  {% if is_manifold %}
  // --------------------------------------------------------------------------
  // LieGroupOps concept
  // --------------------------------------------------------------------------

  static constexpr int32_t TangentDim() {
    return LieGroupOps<Self>::TangentDim();
  }

  {% if is_lie_group %}
  static Self FromTangent(const TangentVec& vec, const Scalar epsilon = kDefaultEpsilon<Scalar>) {
    return LieGroupOps<Self>::FromTangent(vec, epsilon);
  }

  TangentVec ToTangent(const Scalar epsilon = kDefaultEpsilon<Scalar>) const {
    return LieGroupOps<Self>::ToTangent(*this, epsilon);
  }

  {% endif %}
  Self Retract(const TangentVec& vec, const Scalar epsilon = kDefaultEpsilon<Scalar>) const {
    return LieGroupOps<Self>::Retract(*this, vec, epsilon);
  }

  TangentVec LocalCoordinates(const Self& b, const Scalar epsilon = kDefaultEpsilon<Scalar>) const {
    return LieGroupOps<Self>::LocalCoordinates(*this, b, epsilon);
  }

  // `b` is taken by const reference, consistent with the other binary operations above.
  Self Interpolate(const Self& b, const Scalar alpha, const Scalar epsilon = kDefaultEpsilon<Scalar>) const {
    return LieGroupOps<Self>::Interpolate(*this, b, alpha, epsilon);
  }
  {% endif %}

  // --------------------------------------------------------------------------
  // General Helpers
  // --------------------------------------------------------------------------

  /**
   * Approximate equality of the underlying storage vectors.
   *
   * Eigen's isApprox is relative: it requires |a - b| <= tol * min(|a|, |b|), so it degenerates to
   * exact equality whenever either side is the zero vector. For that case fall back to an absolute
   * check on the other side's norm, for either argument order, so the result does not depend on
   * which object is `this` and which is `b`.
   * https://eigen.tuxfamily.org/dox/classEigen_1_1DenseBase.html#ae8443357b808cd393be1b51974213f9c
   */
  bool IsApprox(const Self& b, const Scalar tol) const {
    if (b.Data() == DataVec::Zero()) {
      return Data().norm() < tol;
    }
    if (Data() == DataVec::Zero()) {
      return b.Data().norm() < tol;
    }

    return Data().isApprox(b.Data(), tol);
  }

  // Convert to another scalar type. The storage is cast element-wise and passed through the
  // normalizing constructor.
  template <typename ToScalar>
  {{ cls.__name__ }}<ToScalar> Cast() const {
    return {{ cls.__name__ }}<ToScalar>(Data().template cast<ToScalar>());
  }

  // Exact element-wise comparison of storage; use IsApprox for a tolerance-based check.
  bool operator==(const {{ cls.__name__ }}& rhs) const {
    return data_ == rhs.Data();
  }

  bool operator!=(const {{ cls.__name__ }}& rhs) const {
    return !(*this == rhs);
  }

 protected:
  DataVec data_;
};

// Per-scalar shorthands; the suffix is the first letter of the scalar type's name
{% for scalar_type in scalar_types %}
using {{ cls.__name__ ~ scalar_type[0] }} = {{ cls.__name__ }}<{{ scalar_type }}>;
{% endfor %}

// Stream-output operator declarations, one per supported scalar type
{% for scalar_type in scalar_types %}
std::ostream& operator<<(std::ostream& os, const {{ cls.__name__ }}<{{ scalar_type }}>& a);
{% endfor %}

}  // namespace sym

// Externs to reduce duplicate instantiation: including translation units do not implicitly
// instantiate these specializations (presumably they are explicitly instantiated in a single
// source file elsewhere - confirm against the build).
{% for scalar in scalar_types %}
extern template class sym::{{ cls.__name__ }}<{{ scalar }}>;
{% endfor %}

// Layout checks: the object must be exactly its storage (no padding) with scalar alignment.
// Messages are given explicitly because static_assert without a message requires C++17.
{% for scalar in scalar_types %}
static_assert(
  sizeof(sym::{{ cls.__name__ }}<{{ scalar }}>)
  == {{ ops.StorageOps.storage_dim(cls) }} * sizeof({{ scalar }}),
  "sym::{{ cls.__name__ }}<{{ scalar }}> must contain exactly its storage, with no padding"
);
static_assert(
  alignof(sym::{{ cls.__name__ }}<{{ scalar }}>) == sizeof({{ scalar }}),
  "sym::{{ cls.__name__ }}<{{ scalar }}> must have the alignment of a single {{ scalar }}"
);
{% endfor %}

// Concept implementations for this class
#include "./ops/{{ camelcase_to_snakecase(cls.__name__) }}/storage_ops.h"
#include "./ops/{{ camelcase_to_snakecase(cls.__name__) }}/lie_group_ops.h"
#include "./ops/{{ camelcase_to_snakecase(cls.__name__) }}/group_ops.h"
